import sympy
import re
from sympy.codegen.ast import Expr
# Matches C calls of the form ``pow(base, exp)`` as printed by sympy's ccode.
POW_RE = re.compile(r'pow\((?P<base>[^,]*)\s*,\s*(?P<exp>[^\)]*)\)')
def expandpow(m):
  """Rewrite a ``pow(base, n)`` match as an explicit product.

  Intended as the replacement callback of ``POW_RE.sub``.  For a positive
  integer exponent ``n`` it returns ``(base)*(base)*...`` with ``n``
  factors.  Any other exponent (zero, negative, fractional or symbolic) is
  left as the original ``pow(...)`` text, which is still valid C.  Expanding
  those cases would yield an empty string or a wrong value.
  """
  try:
    exp = int(m.group('exp'))
  except ValueError:
    # e.g. 'pow(x, 0.5)' or a symbolic exponent: keep the call as-is
    return m.group(0)
  if exp < 1:
    return m.group(0)
  return '*'.join(['(' + m.group('base') + ')'] * exp)
def rdot(i0, i1, i2, i3):
  """Return the sympy symbol naming the dot product of bonds (i0,i1) and (i2,i3).

  The dot product is symmetric, so the two index pairs are first put in
  lexicographic order.  Both ``rdot(0, 2, 0, 1)`` and ``rdot(0, 1, 0, 2)``
  therefore yield the symbol ``r0102``.
  """
  lo, hi = sorted([(i0, i1), (i2, i3)])
  return sympy.symbols('r%d%d%d%d' % (lo + hi))
def generate_shake_series(natoms, constraints, name='shake', nodt=False):
  """Print a C function that enforces a set of SHAKE distance constraints.

  The emitted function solves for one multiplier ``l<ik>`` per constraint
  by fixed-point iteration.  The linear part of the constraint equations is
  inverted symbolically with sympy.  The quadratic part is re-evaluated
  from the previous multipliers on every pass.

  Args:
    natoms: number of atoms in the cluster.  Atom 0 is the C argument
      ``i``.  Atom k (k >= 1) is looked up through
      ``cell->bonded_id[cell->first_bonded[i] + cell->shake[i].idx[k-1]]``.
    constraints: list of ``[k1, k2]`` atom-index pairs, one per constrained
      bond.  The target squared length of constraint ``ik`` is read from
      ``cell->shake[i].rsq[ik]``.  Generated names concatenate indices
      with ``%d%d`` (e.g. ``r01``, ``a01``), so indices are expected to be
      single digits.
    name: name of the emitted C function.
    nodt: if False, the C function takes ``dtfsq``, measures the
      unconstrained bond vectors from ``cell->shake_xuc``, divides the
      multipliers by ``dtfsq`` and adds ``r * l`` to ``cell->f``.  If True,
      it measures them from ``cell->x`` and adds ``r * (l * rmass)`` to
      ``cell->shake_xuc``.

  Raises:
    ValueError: if ``constraints`` is empty.

  The C source is written to stdout; nothing is returned.
  """
  ncons = len(constraints)
  if ncons == 0:
    raise ValueError('generate_shake_series needs at least one constraint')

  def shift(k1, k2, b1, b2):
    # Constraint (b1, b2) moves atom b1 along +r_b1b2 and atom b2 along
    # -r_b1b2, each scaled by its rmass (see the final update loop).  The
    # resulting change of the (k1, k2) bond vector is
    # (c1*rmass_k1 + c2*rmass_k2) * l * r_b1b2 with the returned (c1, c2).
    c1 = int(k1 == b1) - int(k1 == b2)
    c2 = int(k2 == b2) - int(k2 == b1)
    return c1, c2

  # --- signature and per-atom / per-constraint inputs ----------------------
  args = ['celldata_t *cell', 'int i']
  if nodt:
    args.append('real tol, int maxiter')
  else:
    args.append('real dtfsq, real tol, int maxiter')
  print('void %s(%s){' % (name, ', '.join(args)))
  print('  int ig0 = i;')
  for i in range(1, natoms):
    print('  int ig%d = cell->bonded_id[cell->first_bonded[i] + cell->shake[i].idx[%d]];' % (i, i - 1))
  for i in range(natoms):
    print('  real rmass%d = cell->rmass[ig%d];' % (i, i))
  for ik, (k1, k2) in enumerate(constraints):
    print('  real d%d%dsq = cell->shake[i].rsq[%d];' % (k1, k2, ik))

  # symbols for the per-atom rmass values used in the quadratic coefficients
  rm = [sympy.symbols('rmass%s' % (i)) for i in range(natoms)]

  # --- bond vectors: r (current positions) and ruc (unconstrained) ---------
  rs = ['r%d%d, ruc%d%d' % (k1, k2, k1, k2) for k1, k2 in constraints]
  print('  vec<real> ' + ', '.join(rs) + ';')
  for k1, k2 in constraints:
    print('  r%d%d = cell->x[ig%d] - cell->x[ig%d];' % (k1, k2, k1, k2))
  ruc_src = 'x' if nodt else 'shake_xuc'
  for k1, k2 in constraints:
    print('  ruc%d%d = cell->%s[ig%d] - cell->%s[ig%d];' % (k1, k2, ruc_src, k1, ruc_src, k2))
  # NOTE(review): the result is named ...sq and compared against the squared
  # target d..sq, so vec<real>::norm() is presumably the *squared* length --
  # confirm against the vec implementation.
  for k1, k2 in constraints:
    print('  real ruc%d%dsq = ruc%d%d.norm();' % (k1, k2, k1, k2))

  # Pairwise dot products of the reference bond vectors.  They are declared
  # under the same (ordered) names that rdot() puts into the quad terms, so
  # the declarations match the uses for any constraint order.
  for ikk, (ikk1, ikk2) in enumerate(constraints):
    for jkk1, jkk2 in constraints[ikk:]:
      print('  real %s = r%d%d.dot(r%d%d);' % (rdot(ikk1, ikk2, jkk1, jkk2), ikk1, ikk2, jkk1, jkk2))

  # --- linear coefficient matrix --------------------------------------------
  # Constraint ik: |ruc_k + sum_c (c1*rmass_k1 + c2*rmass_k2) * l_c * r_c|^2
  # = d_k^2.  Its linear term in l_ikk is
  # 2 * (c1*rmass_k1 + c2*rmass_k2) * (r_kk . ruc_k): the row's own
  # unconstrained vector ruc_k1k2 dotted with the column's reference
  # vector r_kk1kk2.
  signs = {1: '+', -1: '-'}
  for ik, (k1, k2) in enumerate(constraints):
    for ikk, (kk1, kk2) in enumerate(constraints):
      c1, c2 = shift(k1, k2, kk1, kk2)
      terms = ['%srmass%d' % (signs[c], k) for c, k in ((c1, k1), (c2, k2)) if c in signs]
      # bonds sharing no atom couple with coefficient 0 (avoid emitting '*()')
      coeff = ''.join(terms) or '0'
      print('  real a%d%d = 2 * r%d%d.dot(ruc%d%d)*(' % (ik, ikk, kk1, kk2, k1, k2) + coeff + ');')

  # Symbolic inverse: emit 1/det and the adjugate entries (inv * det).
  mat = sympy.Matrix([[sympy.symbols('a%d%d' % (i, j)) for j in range(ncons)] for i in range(ncons)])
  matdet = mat.det()
  print('  real invmatdet = 1/(' + sympy.printing.ccode(matdet) + ');')
  matinvdet = sympy.simplify(mat.inv() * matdet)
  for i in range(ncons):
    for j in range(ncons):
      print('  real a%d%dinv = (' % (i, j) + sympy.printing.ccode(matinvdet[i, j]) + ') * invmatdet;')

  # --- quadratic coefficients -----------------------------------------------
  # quad<ik>_<ikk><jkk> multiplies l_ikk * l_jkk in constraint ik.  Only
  # jkk >= ikk is emitted, so off-diagonal terms are doubled.
  for ik, (k1, k2) in enumerate(constraints):
    for ikk, (ikk1, ikk2) in enumerate(constraints):
      ci1, ci2 = shift(k1, k2, ikk1, ikk2)
      for jkk in range(ikk, ncons):
        jkk1, jkk2 = constraints[jkk]
        cj1, cj2 = shift(k1, k2, jkk1, jkk2)
        quad = (ci1 * rm[k1] + ci2 * rm[k2]) * (cj1 * rm[k1] + cj2 * rm[k2]) * rdot(ikk1, ikk2, jkk1, jkk2)
        if jkk > ikk:
          quad = quad * 2
        # expand pow(x, n) into explicit products in the emitted C
        quad_c = POW_RE.sub(expandpow, sympy.printing.ccode(quad))
        print('  real quad%d_%d%d = ' % (ik, ikk, jkk) + quad_c + ';')

  # --- fixed-point iteration: l_new = A^-1 (d^2 - ruc^2 - quad(l)) -----------
  for ik in range(ncons):
    print('  real l%d = 0;' % ik)
  print('  int done = 0;')
  print('  for (int i = 0; i < maxiter; i ++) {')
  for ik in range(ncons):
    quadterms = ['quad%d_%d%d*l%d*l%d' % (ik, ikk, jkk, ikk, jkk)
                 for ikk in range(ncons) for jkk in range(ikk, ncons)]
    print('    real quad%d = ' % ik + ' + '.join(quadterms) + ';')
  for ik, (k1, k2) in enumerate(constraints):
    print('    real b%d = d%d%dsq - ruc%d%dsq - quad%d;' % (ik, k1, k2, k1, k2, ik))
  for ik in range(ncons):
    lterms = ' + '.join('a%d%dinv*b%d' % (ik, jk, jk) for jk in range(ncons))
    print('    real l%d_new = %s;' % (ik, lterms))
  converged = ' && '.join('fabs(l%d_new - l%d) < tol' % (ik, ik) for ik in range(ncons))
  print('    int almost_done = ' + converged + ';')
  for ik in range(ncons):
    print('    l%d = l%d_new;' % (ik, ik))
  # NOTE(review): `done` lags `almost_done` by one pass, so the C loop breaks
  # on the pass *after* the tolerance test first succeeds -- confirm intended.
  print('    if (done) break;')
  print('    done = almost_done;')
  print('  }')

  # --- apply the multipliers ------------------------------------------------
  if not nodt:
    for ik in range(ncons):
      print('  l%d = l%d / dtfsq;' % (ik, ik))
  for i in range(natoms):
    for ik, (k1, k2) in enumerate(constraints):
      if i == k1:
        sign = '+'
      elif i == k2:
        sign = '-'
      else:
        continue
      if nodt:
        print('  cell->shake_xuc[ig%d] %s= r%d%d * (l%d * rmass%d);' % (i, sign, k1, k2, ik, i))
      else:
        print('  cell->f[ig%d] %s= r%d%d * l%d;' % (i, sign, k1, k2, ik))
  print('}')
if __name__ == '__main__':
  # Emit every solver variant to stdout.  Guarded so that importing this
  # module does not print the generated C code.
  # The single-constraint variant is currently not generated:
  # generate_shake_series(2, [[0, 1]], 'shake2')
  generate_shake_series(3, [[0, 1], [0, 2]], 'shake3')
  generate_shake_series(3, [[0, 1], [0, 2], [1, 2]], 'shake3angle')
  generate_shake_series(4, [[0, 1], [0, 2], [0, 3]], 'shake4')
  generate_shake_series(3, [[0, 1], [0, 2]], 'shake3_nodt', nodt=True)
  generate_shake_series(3, [[0, 1], [0, 2], [1, 2]], 'shake3angle_nodt', nodt=True)
  generate_shake_series(4, [[0, 1], [0, 2], [0, 3]], 'shake4_nodt', nodt=True)